Abstract
Background In the report naomi-simple_fit with parameter tmbstan = TRUE, we used the NUTS algorithm to perform MCMC inference for the simplified Naomi model.
Task Here we assess whether or not the results of the MCMC are suitable using a range of diagnostic tools.
We start by obtaining results from the latest version of naomi-simple_fit with tmbstan = TRUE.
# Load the saved output of the model-fitting report and keep a handle on the
# fitted Stan object; every diagnostic below is computed from `mcmc`.
out <- readRDS(file = "depends/out.rds")
mcmc <- out$mcmc$stanfit
This MCMC took 3.28 days to run.
# Eight-colour palette (presumably colour-blind friendly, going by the name);
# the manual colour scales further down reuse some of these hex codes.
cbpalette <- c(
  "#56B4E9", "#009E73", "#E69F00", "#F0E442",
  "#0072B2", "#D55E00", "#CC79A7", "#999999"
)

# Global plotting defaults: bayesplot colour scheme and ggplot2 theme
bayesplot::color_scheme_set(scheme = "viridisA")
ggplot2::theme_set(ggplot2::theme_minimal())
We are looking for values of \(\hat R\) less than 1.1 here; the dashed line in the plot marks the stricter value of 1.05.
# Potential scale reduction factor (R-hat) for each parameter in the fit
rhats <- bayesplot::rhat(mcmc)

# Build the plotting data up front, because the segment start point is computed
# from it. The previous version read `data$value`, but no object named `data`
# is defined in this script.
rhat_data <- mcmc_rhat_data(rhats)

# Segments start at 1 when any R-hat is below 1, otherwise at the left edge of
# the panel. A plain `if` is used because the condition is a single value.
segment_start <- if (isTRUE(min(rhat_data$value, na.rm = TRUE) < 1)) 1 else -Inf

# One horizontal segment per parameter, ending at its R-hat value.
# `aes()` replaces the deprecated `aes_()`.
rhat_plot <- ggplot(
  rhat_data,
  mapping = aes(x = value, y = parameter, col = description)
) +
  geom_segment(
    mapping = aes(yend = parameter, xend = segment_start),
    na.rm = TRUE,
    alpha = 0.7
  ) +
  # Only one colour is supplied, so this relies on `description` taking a
  # single value (i.e. all R-hats falling in the same category).
  scale_color_manual(values = "#E69F00") +
  geom_vline(xintercept = 1.05, linetype = "dashed", col = "grey40") +
  labs(x = "Potential scale reduction factor", y = "NUTS parameter", col = "") +
  theme_minimal() +
  # There are too many parameters to label individually on the y-axis
  theme(
    axis.text.y = element_blank(),
    axis.ticks.y = element_blank(),
    panel.grid.major = element_blank()
  )

rhat_plot
# Save the R-hat plot as a 6.25 x 3 inch PDF. Argument names are spelled out in
# full rather than relying on partial matching of `h`/`w`, and the plot is
# printed explicitly so it is drawn on the open PDF device.
pdf("rhat.pdf", height = 3, width = 6.25)
print(rhat_plot)
dev.off()
## quartz_off_screen
## 2
# Parameters whose R-hat exceeds 1.1 (printed), followed by the proportion of
# all parameters that they make up.
big_rhats <- rhats[rhats > 1.1]
print(big_rhats)
## named numeric(0)
length(big_rhats) / length(rhats)
## [1] 0
It is reasonable to be worried about effective sample size (ESS) ratios less than 0.1 here.
# Ratio of effective sample size to total sample size for each parameter
ratios <- bayesplot::neff_ratio(mcmc)

# NOTE(review): `breaks` is not used by the plot below; it is kept in case it is
# referenced elsewhere. Confirm and remove if it is not.
breaks <- c(0, 0.1, 0.25, 0.5, 0.75, 1)

ratio_data <- mcmc_neff_data(ratios)

# One horizontal segment per parameter, running from the left edge of the panel
# to its ESS ratio. `aes()` replaces the deprecated `aes_()`.
ratio_plot <- ggplot(
  ratio_data,
  mapping = aes(x = value, y = parameter, color = description)
) +
  geom_segment(
    mapping = aes(yend = parameter, xend = -Inf),
    na.rm = TRUE,
    alpha = 0.7
  ) +
  scale_color_manual(values = c("#56B4E9", "#009E73", "#E69F00")) +
  # Dashed reference lines at ratios of 0.1, 0.5 and 1
  geom_vline(xintercept = 0.1, linetype = "dashed", col = "grey40") +
  geom_vline(xintercept = 0.5, linetype = "dashed", col = "grey40") +
  geom_vline(xintercept = 1, linetype = "dashed", col = "grey40") +
  labs(x = "ESS ratio", y = "NUTS parameter", col = "") +
  theme_minimal() +
  # There are too many parameters to label individually on the y-axis
  theme(
    axis.text.y = element_blank(),
    axis.ticks.y = element_blank(),
    panel.grid.major = element_blank()
  )

ratio_plot
# Save the ESS ratio plot as a 6.25 x 3 inch PDF. Argument names are spelled out
# in full rather than relying on partial matching of `h`/`w`, and the plot is
# printed explicitly so it is drawn on the open PDF device.
pdf("ratio.pdf", height = 3, width = 6.25)
print(ratio_plot)
dev.off()
## quartz_off_screen
## 2
What are the total effective sample sizes?
#' For a stanfit object, the `$summary` element of `summary()` holds statistics
#' computed with all chains merged (per-chain values are in `$c_summary`), so
#' `n_eff` below is the total effective sample size across chains.
mcmc_summary <- summary(mcmc)$summary

# One row per parameter, with the parameter name moved into a `param` column
ess_data <- tibble::rownames_to_column(data.frame(mcmc_summary), "param")

# Histogram of effective sample sizes over all parameters
ess_plot <- ggplot(ess_data, aes(x = n_eff)) +
  geom_histogram(alpha = 0.8) +
  labs(x = "ESS", y = "Count")

ess_plot
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
# Save the ESS histogram as a 6.25 x 3 inch PDF. Argument names are spelled out
# in full rather than relying on partial matching of `h`/`w`, and the plot is
# printed explicitly so it is drawn on the open PDF device.
pdf("ess.pdf", height = 3, width = 6.25)
print(ess_plot)
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
dev.off()
## quartz_off_screen
## 2
How much autocorrelation is there in the chains?
# Autocorrelation plots for the parameters whose names start with "beta"
bayesplot::mcmc_acf(mcmc, pars = vars(starts_with("beta")))

# Parameter-name prefixes to draw traceplots for, one plot per prefix.
# The trailing "[" on the indexed parameters stops a prefix such as "u_rho_x["
# from also matching the "u_rho_xs[...]" parameters.
trace_prefixes <- c(
  "beta",
  "logit",
  "log_sigma",
  "u_rho_x[",
  "u_rho_xs[",
  "us_rho_x[",
  "us_rho_xs[",
  "u_rho_a[",
  "u_rho_as[",
  "u_alpha_x[",
  "u_alpha_xs[",
  "us_alpha_x[",
  "us_alpha_xs[",
  "u_alpha_a[",
  "u_alpha_as[",
  "u_alpha_xa[",
  "ui_lambda_x[",
  "ui_anc_rho_x[",
  "ui_anc_alpha_x[",
  #' N.B. the text below ties log_or_gamma to the ART attendance model
  "log_or_gamma["
)

# Plots are not auto-printed inside a loop, so print each one explicitly.
# `!!prefix` inlines the current prefix into the selection expression.
for (prefix in trace_prefixes) {
  print(bayesplot::mcmc_trace(mcmc, pars = vars(starts_with(!!prefix))))
}
There is a prior suspicion (from Jeff, Tim, Rachel) that the ART attendance model is unidentifiable.
Let’s have a look at the pairs plot for neighbouring districts and the log_or_gamma parameter.
# Read the demo area boundaries shipped with the naomi package
areas_path <- system.file("extdata/demo_areas.geojson", package = "naomi")
area_merged <- sf::read_sf(areas_path)

# Keep only the areas at the highest `area_level` value and build their
# neighbour structure; `nb[[i]]` is used below as the indices of the
# neighbours of area `i`.
finest_areas <- filter(area_merged, area_level == max(area_level))
nb <- bsae::sf_to_nb(finest_areas)
#' Pairs plot of an indexed parameter for one area and its neighbours
#'
#' Reads the global objects `nb` (neighbour structure) and `mcmc` (fitted
#' model).
#'
#' @param par Name of an indexed parameter without its index, e.g.
#'   `"log_or_gamma"`; elements are addressed as `par[index]`.
#' @param i Index of the focal area in `nb`.
#' @return The value of `bayesplot::mcmc_pairs()` for the focal area's element
#'   of `par` together with those of its neighbours.
neighbours_pairs_plot <- function(par, i) {
  # Focal area first, followed by its neighbours
  area_indices <- c(i, nb[[i]])
  neighbour_pars <- paste0(par, "[", area_indices, "]")

  bayesplot::mcmc_pairs(
    mcmc,
    pars = neighbour_pars,
    diag_fun = "hist",
    off_diag_fun = "hex"
  )
}
# area_merged %>%
# filter(area_level == max(area_level)) %>%
# print(n = Inf)
Here are Nkhata Bay and neighbours:
# Area index 5 (Nkhata Bay, per the text above) and its neighbours
neighbours_pairs_plot(par = "log_or_gamma", i = 5)
And here are Blantyre and neighbours:
# Area index 26 (Blantyre, per the text above) and its neighbours
neighbours_pairs_plot(par = "log_or_gamma", i = 26)
# Extract the NUTS sampler parameters from the fit and save them to disk
np <- bayesplot::nuts_params(mcmc)
saveRDS(object = np, file = "nuts-params.rds")
Are there any divergent transitions?
# Total number of divergent transitions: sum of the `divergent__` values
divergent_draws <- filter(np, Parameter == "divergent__")
summarise(divergent_draws, n_divergent = sum(Value))

# Divergence diagnostic plot, using the log-posterior of the draws
log_post <- bayesplot::log_posterior(mcmc)
bayesplot::mcmc_nuts_divergence(np, log_post)
We can also use energy plots (Betancourt 2017): ideally these two histograms would be the same. When the histograms are quite different, it may suggest the chains are not fully exploring the tails of the target distribution.
# Energy diagnostic histograms built from the NUTS sampler parameters
bayesplot::mcmc_nuts_energy(x = np)
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
# Record the R version and package versions used to produce this report
utils::sessionInfo()
## R version 4.2.0 (2022-04-22)
## Platform: x86_64-apple-darwin17.0 (64-bit)
## Running under: macOS 13.3.1
##
## Matrix products: default
## LAPACK: /Library/Frameworks/R.framework/Versions/4.2/Resources/lib/libRlapack.dylib
##
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
##
## attached base packages:
## [1] stats graphics grDevices utils datasets methods base
##
## other attached packages:
## [1] patchwork_1.1.2 tibble_3.2.1 tidyverse_1.3.1 rmarkdown_2.18 multi.utils_0.1.0
## [6] sf_1.0-9 bayesplot_1.9.0 rstan_2.21.5 StanHeaders_2.21.0-7 Matrix_1.5-4.1
## [11] stringr_1.5.0 purrr_1.0.1 tidyr_1.2.1 readr_2.1.3 ggplot2_3.4.0
## [16] forcats_0.5.2 dplyr_1.0.10
##
## loaded via a namespace (and not attached):
## [1] colorspace_2.0-3 deldir_1.0-6 ellipsis_0.3.2 class_7.3-20 ggridges_0.5.3
## [6] fs_1.6.1 rstudioapi_0.14 proxy_0.4-27 farver_2.1.1 hexbin_1.28.2
## [11] bit64_4.0.5 lubridate_1.8.0 fansi_1.0.4 mvtnorm_1.1-3 xml2_1.3.3
## [16] codetools_0.2-18 cachem_1.0.6 knitr_1.41 jsonlite_1.8.4 gt_0.8.0
## [21] broom_0.8.0 naomi_2.8.5 dbplyr_2.1.1 tmbstan_1.0.4 httr_1.4.4
## [26] compiler_4.2.0 backports_1.4.1 assertthat_0.2.1 fastmap_1.1.0 cli_3.6.1
## [31] s2_1.1.1 htmltools_0.5.3 prettyunits_1.1.1 tools_4.2.0 gtable_0.3.1
## [36] glue_1.6.2 reshape2_1.4.4 posterior_1.2.2 wk_0.7.0 V8_4.2.2
## [41] Rcpp_1.0.10 cellranger_1.1.0 jquerylib_0.1.4 vctrs_0.6.1 spdep_1.2-7
## [46] mvQuad_1.0-6 tensorA_0.36.2 xfun_0.37 ps_1.7.3 rvest_1.0.2
## [51] bsae_0.2.7 lifecycle_1.0.3 statmod_1.4.36 orderly_1.4.3 ids_1.0.1
## [56] scales_1.2.1 hms_1.1.2 parallel_4.2.0 inline_0.3.19 TMB_1.9.2
## [61] yaml_2.3.7 curl_5.0.0 memoise_2.0.1 gridExtra_2.3 loo_2.5.1
## [66] sass_0.4.4 stringi_1.7.8 RSQLite_2.2.14 highr_0.9 e1071_1.7-12
## [71] checkmate_2.1.0 boot_1.3-28 pkgbuild_1.3.1 spData_2.2.1 rlang_1.1.0
## [76] pkgconfig_2.0.3 matrixStats_0.62.0 distributional_0.3.0 evaluate_0.20 lattice_0.20-45
## [81] aghq_0.4.1 labeling_0.4.2 bit_4.0.5 processx_3.8.0 tidyselect_1.2.0
## [86] traduire_0.0.6 plyr_1.8.8 magrittr_2.0.3 bookdown_0.26 R6_2.5.1
## [91] generics_0.1.3 DBI_1.1.3 haven_2.5.0 pillar_1.9.0 withr_2.5.0
## [96] units_0.8-0 abind_1.4-5 sp_1.5-1 modelr_0.1.8 crayon_1.5.2
## [101] uuid_1.1-0 KernSmooth_2.23-20 utf8_1.2.3 tzdb_0.3.0 readxl_1.4.0
## [106] grid_4.2.0 data.table_1.14.6 blob_1.2.3 callr_3.7.3 reprex_2.0.1
## [111] digest_0.6.31 classInt_0.4-8 numDeriv_2016.8-1.1 openssl_2.0.5 RcppParallel_5.1.5
## [116] stats4_4.2.0 munsell_0.5.0 bslib_0.4.1 askpass_1.1